{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<i>Copyright (c) Recommenders contributors.</i>\n",
    "\n",
    "<i>Licensed under the MIT License.</i>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EmbeddingDotBias Recommender\n",
    "\n",
    "This notebook shows how to use `EmbeddingDotBias`, a model similar to fastai's [EmbeddingDotBias](https://docs.fast.ai/collab.html#embeddingdotbias) but implemented directly in PyTorch. The model learns an embedding for each user and each item."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "System version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]\n",
      "Pandas version: 2.2.2\n",
      "PyTorch version: 2.3.1+cu121\n",
      "CUDA Available: True\n",
      "CuDNN Enabled: True\n"
     ]
    }
   ],
   "source": [
    "# Suppress all warnings\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import logging\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "from recommenders.utils.constants import (\n",
    "    DEFAULT_USER_COL as USER, \n",
    "    DEFAULT_ITEM_COL as ITEM, \n",
    "    DEFAULT_RATING_COL as RATING, \n",
    "    DEFAULT_TIMESTAMP_COL as TIMESTAMP, \n",
    "    DEFAULT_PREDICTION_COL as PREDICTION\n",
    ")\n",
    "\n",
    "from recommenders.datasets import movielens\n",
    "from recommenders.datasets.python_splitters import python_stratified_split\n",
    "from recommenders.evaluation.python_evaluation import (exp_var, mae, map,\n",
    "                                                       ndcg_at_k,\n",
    "                                                       precision_at_k,\n",
    "                                                       recall_at_k, rmse,\n",
    "                                                       rsquared)\n",
    "from recommenders.models.embdotbias.data_loader import RecoDataLoader\n",
    "from recommenders.models.embdotbias.model import EmbeddingDotBias\n",
    "from recommenders.models.embdotbias.training_utils import (Trainer,\n",
    "                                                           predict_rating)\n",
    "from recommenders.models.embdotbias.utils import cartesian_product, score\n",
    "from recommenders.utils.notebook_utils import store_metadata\n",
    "from recommenders.utils.timer import Timer\n",
    "\n",
    "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s - %(message)s\")\n",
    "\n",
    "print(f\"System version: {sys.version}\")\n",
    "print(f\"Pandas version: {pd.__version__}\")\n",
    "print(f\"PyTorch version: {torch.__version__}\")\n",
    "print(f\"CUDA Available: {torch.cuda.is_available()}\")\n",
    "print(f\"CuDNN Enabled: {torch.backends.cudnn.enabled}\")"
   ]
  },
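  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the notebook parameters: the number of items to recommend, the MovieLens dataset size and the model hyperparameters. The constants that refer to the different columns of our dataset (`USER`, `ITEM`, `RATING`, `TIMESTAMP`, `PREDICTION`) were imported in the cell above."
   ]
  },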
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# top k items to recommend\n",
    "TOP_K = 10\n",
    "\n",
    "# Select MovieLens data size: 100k, 1m, 10m, or 20m\n",
    "MOVIELENS_DATA_SIZE = \"100k\"\n",
    "\n",
    "# Model parameters\n",
    "N_FACTORS = 40\n",
    "EPOCHS = 7\n",
    "SEED = 101"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - Downloading https://files.grouplens.org/datasets/movielens/ml-100k.zip\n",
      "100%|██████████| 4.81k/4.81k [00:01<00:00, 3.56kKB/s]\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>196</td>\n",
       "      <td>242</td>\n",
       "      <td>3.0</td>\n",
       "      <td>881250949</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>186</td>\n",
       "      <td>302</td>\n",
       "      <td>3.0</td>\n",
       "      <td>891717742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>22</td>\n",
       "      <td>377</td>\n",
       "      <td>1.0</td>\n",
       "      <td>878887116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>244</td>\n",
       "      <td>51</td>\n",
       "      <td>2.0</td>\n",
       "      <td>880606923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>166</td>\n",
       "      <td>346</td>\n",
       "      <td>1.0</td>\n",
       "      <td>886397596</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  userID itemID  rating  timestamp\n",
       "0    196    242     3.0  881250949\n",
       "1    186    302     3.0  891717742\n",
       "2     22    377     1.0  878887116\n",
       "3    244     51     2.0  880606923\n",
       "4    166    346     1.0  886397596"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratings_df = movielens.load_pandas_df(\n",
    "    size=MOVIELENS_DATA_SIZE,\n",
    "    header=[USER,ITEM,RATING,TIMESTAMP]\n",
    ")\n",
    "\n",
    "# Make sure the IDs are loaded as strings to better prevent confusion with embedding ids\n",
    "ratings_df[USER] = ratings_df[USER].astype(\"str\")\n",
    "ratings_df[ITEM] = ratings_df[ITEM].astype(\"str\")\n",
    "\n",
    "ratings_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the dataset\n",
    "train_valid_df, test_df = python_stratified_split(\n",
    "    ratings_df,\n",
    "    ratio=0.75, \n",
    "    min_rating=1, \n",
    "    filter_by=\"item\", \n",
    "    col_user=USER, \n",
    "    col_item=ITEM,\n",
    "    seed=SEED\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>99941</th>\n",
       "      <td>593</td>\n",
       "      <td>1</td>\n",
       "      <td>3.0</td>\n",
       "      <td>875659150</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>63031</th>\n",
       "      <td>879</td>\n",
       "      <td>1</td>\n",
       "      <td>4.0</td>\n",
       "      <td>887761865</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66516</th>\n",
       "      <td>216</td>\n",
       "      <td>1</td>\n",
       "      <td>4.0</td>\n",
       "      <td>880232615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21048</th>\n",
       "      <td>200</td>\n",
       "      <td>1</td>\n",
       "      <td>5.0</td>\n",
       "      <td>876042340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78925</th>\n",
       "      <td>933</td>\n",
       "      <td>1</td>\n",
       "      <td>3.0</td>\n",
       "      <td>874854294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10413</th>\n",
       "      <td>336</td>\n",
       "      <td>999</td>\n",
       "      <td>2.0</td>\n",
       "      <td>877757516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7847</th>\n",
       "      <td>125</td>\n",
       "      <td>999</td>\n",
       "      <td>4.0</td>\n",
       "      <td>892838288</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34637</th>\n",
       "      <td>417</td>\n",
       "      <td>999</td>\n",
       "      <td>3.0</td>\n",
       "      <td>880952434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42623</th>\n",
       "      <td>476</td>\n",
       "      <td>999</td>\n",
       "      <td>2.0</td>\n",
       "      <td>883365385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98226</th>\n",
       "      <td>682</td>\n",
       "      <td>999</td>\n",
       "      <td>2.0</td>\n",
       "      <td>888521942</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>75066 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      userID itemID  rating  timestamp\n",
       "99941    593      1     3.0  875659150\n",
       "63031    879      1     4.0  887761865\n",
       "66516    216      1     4.0  880232615\n",
       "21048    200      1     5.0  876042340\n",
       "78925    933      1     3.0  874854294\n",
       "...      ...    ...     ...        ...\n",
       "10413    336    999     2.0  877757516\n",
       "7847     125    999     4.0  892838288\n",
       "34637    417    999     3.0  880952434\n",
       "42623    476    999     2.0  883365385\n",
       "98226    682    999     2.0  888521942\n",
       "\n",
       "[75066 rows x 4 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_valid_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove \"cold\" users from test set \n",
    "test_df = test_df[test_df[USER].isin(train_valid_df[USER])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix random seeds to make sure the runs are reproducible\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = RecoDataLoader.from_df(\n",
    "    train_valid_df,\n",
    "    user_name=USER,\n",
    "    item_name=ITEM,\n",
    "    rating_name=RATING,\n",
    "    valid_pct=0.1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Showing a sample batch:\n",
      "Showing 5 examples from a batch:\n",
      "  userID itemID  rating\n",
      "0    710    302     4.0\n",
      "1    588    554     3.0\n",
      "2     92    452     2.0\n",
      "3    727     56     3.0\n",
      "4    535    212     4.0\n"
     ]
    }
   ],
   "source": [
    "data.show_batch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using 40 latent factors. This will create an embedding for the users and the items that will map each of these to 40 floats as can be seen below. Note that the embedding parameters are not predefined, but are learned by the model.\n",
    "\n",
    "Although ratings can only range from 1-5, we are setting the range of possible ratings to a range from 0 to 5.5 -- that will allow the model to predict values around 1 and 5, which improves accuracy. Lastly, we set a value for weight-decay for regularization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = EmbeddingDotBias.from_classes(\n",
    "    n_factors=N_FACTORS,\n",
    "    classes=data.classes,\n",
    "    user=USER,\n",
    "    item=ITEM,\n",
    "    y_range=[0,5.5]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now train the model for 7 epochs setting the maximal learning rate. The learner will reduce the learning rate with each epoch using cosine annealing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - Epoch 1/7:\n",
      "INFO - Train Loss: 1.3875741174613887\n",
      "INFO - Valid Loss: 1.0270110111115343\n",
      "INFO - Train Loss: 1.3875741174613887\n",
      "INFO - Valid Loss: 1.0270110111115343\n",
      "INFO - Epoch 2/7:\n",
      "INFO - Train Loss: 0.908381488456419\n",
      "INFO - Valid Loss: 0.9222675213369272\n",
      "INFO - Epoch 2/7:\n",
      "INFO - Train Loss: 0.908381488456419\n",
      "INFO - Valid Loss: 0.9222675213369272\n",
      "INFO - Epoch 3/7:\n",
      "INFO - Train Loss: 0.821684703202636\n",
      "INFO - Valid Loss: 0.8861896385580806\n",
      "INFO - Epoch 3/7:\n",
      "INFO - Train Loss: 0.821684703202636\n",
      "INFO - Valid Loss: 0.8861896385580806\n",
      "INFO - Epoch 4/7:\n",
      "INFO - Train Loss: 0.7628276834941723\n",
      "INFO - Valid Loss: 0.8663221562312822\n",
      "INFO - Epoch 4/7:\n",
      "INFO - Train Loss: 0.7628276834941723\n",
      "INFO - Valid Loss: 0.8663221562312822\n",
      "INFO - Epoch 5/7:\n",
      "INFO - Train Loss: 0.7107005440488909\n",
      "INFO - Valid Loss: 0.8576887482303684\n",
      "INFO - Epoch 5/7:\n",
      "INFO - Train Loss: 0.7107005440488909\n",
      "INFO - Valid Loss: 0.8576887482303684\n",
      "INFO - Epoch 6/7:\n",
      "INFO - Train Loss: 0.6560591047234607\n",
      "INFO - Epoch 6/7:\n",
      "INFO - Train Loss: 0.6560591047234607\n",
      "INFO - Valid Loss: 0.8523229274709346\n",
      "INFO - Valid Loss: 0.8523229274709346\n",
      "INFO - Epoch 7/7:\n",
      "INFO - Train Loss: 0.5980674705439897\n",
      "INFO - Valid Loss: 0.8517736125800569\n",
      "INFO - Epoch 7/7:\n",
      "INFO - Train Loss: 0.5980674705439897\n",
      "INFO - Valid Loss: 0.8517736125800569\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Took 63.6546 seconds for training.\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(model=model)\n",
    "\n",
    "with Timer() as train_time:\n",
    "    trainer.fit(data.train, data.valid, EPOCHS)\n",
    "\n",
    "print(f\"Took {train_time} seconds for training.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the model weights (its `state_dict`) so they can be loaded back later for inference / generating recommendations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved to: /tmp/tmp3det9ii6/embdotbias_model.pth\n"
     ]
    }
   ],
   "source": [
    "tmp = TemporaryDirectory()\n",
    "model_path = os.path.join(tmp.name, \"embdotbias_model.pth\")\n",
    "\n",
    "torch.save(model.state_dict(), model_path)\n",
    "print(f\"Model saved to: {model_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Recommendations\n",
    "\n",
    "Re-create the model and load the saved weights from disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded successfully.\n"
     ]
    }
   ],
   "source": [
    "# Re-create the model with the same settings that were used for training;\n",
    "# the saved file only contains the weights, not the architecture.\n",
    "loaded_model = EmbeddingDotBias.from_classes(\n",
    "    n_factors=N_FACTORS,\n",
    "    classes=data.classes,\n",
    "    user=USER,\n",
    "    item=ITEM,\n",
    "    y_range=[0,5.5]\n",
    ")\n",
    "\n",
    "# Load the state dictionary. weights_only=True restricts unpickling to tensors and\n",
    "# plain containers, which is all a state_dict needs, and avoids executing arbitrary\n",
    "# pickled code from the file.\n",
    "loaded_model.load_state_dict(torch.load(model_path, weights_only=True))\n",
    "\n",
    "# Set the model to evaluation mode\n",
    "loaded_model.eval()\n",
    "\n",
    "print(\"Model loaded successfully.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all users and items that the model knows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total items & users\n",
    "total_items = loaded_model.classes[ITEM][1:]\n",
    "total_users = loaded_model.classes[USER][1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all users from the test set and remove any users that were not known in the training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_users = test_df[USER].unique()\n",
    "test_users = np.intersect1d(test_users, total_users)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example prediction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "User ID: 864, Item ID: 232\n"
     ]
    }
   ],
   "source": [
    "# Take the first (user, item) index pair of a training batch and translate the\n",
    "# indices back into the raw IDs stored in data.classes.\n",
    "first_batch = next(iter(data.train))\n",
    "batch_inputs = first_batch[0]\n",
    "user_idx, item_idx = (batch_inputs[0, col].item() for col in (0, 1))\n",
    "user_id = data.classes[USER][user_idx]\n",
    "item_id = data.classes[ITEM][item_idx]\n",
    "print(f\"User ID: {user_id}, Item ID: {item_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted rating for user 864 and item 232: 3.881427526473999\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    # Look up the user-side weights (is_item=False) for the sampled user only.\n",
    "    # The item ID must not be passed here: it would be resolved as a user ID.\n",
    "    user_embeddings = loaded_model.weight([user_id], is_item=False)\n",
    "    predicted_rating = predict_rating(loaded_model, user_id, item_id)\n",
    "    print(f\"Predicted rating for user {user_id} and item {item_id}: {predicted_rating}\")\n",
    "except KeyError as e:\n",
    "    # Presumably raised when an ID is unknown to the model -- report it instead of failing the notebook\n",
    "    print(f\"Error: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the cartesian product of test set users and all items known to the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "users_items = cartesian_product(np.array(test_users),np.array(total_items))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "users_items = pd.DataFrame(users_items, columns=[USER,ITEM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586121</th>\n",
       "      <td>99</td>\n",
       "      <td>995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586122</th>\n",
       "      <td>99</td>\n",
       "      <td>996</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586123</th>\n",
       "      <td>99</td>\n",
       "      <td>997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586124</th>\n",
       "      <td>99</td>\n",
       "      <td>998</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586125</th>\n",
       "      <td>99</td>\n",
       "      <td>999</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1586126 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        userID itemID\n",
       "0            1      1\n",
       "1            1     10\n",
       "2            1    100\n",
       "3            1   1000\n",
       "4            1   1001\n",
       "...        ...    ...\n",
       "1586121     99    995\n",
       "1586122     99    996\n",
       "1586123     99    997\n",
       "1586124     99    998\n",
       "1586125     99    999\n",
       "\n",
       "[1586126 rows x 2 columns]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "users_items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Lastly, remove the user/item combinations that are in the training set -- we don't want to propose a movie that the user has already watched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Anti-join: keep only the (user, item) pairs that do NOT appear in the training data.\n",
    "# Only the two key columns take part in the merge; indicator=True adds a \"_merge\" column\n",
    "# that marks unmatched candidate rows as \"left_only\".\n",
    "seen_pairs = train_valid_df[[USER, ITEM]].astype(str)\n",
    "merged_pairs = users_items.merge(seen_pairs, on=[USER, ITEM], how=\"left\", indicator=True)\n",
    "users_items_candidates = merged_pairs.loc[merged_pairs[\"_merge\"] == \"left_only\", [USER, ITEM]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>1002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>1003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>1004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586121</th>\n",
       "      <td>99</td>\n",
       "      <td>995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586122</th>\n",
       "      <td>99</td>\n",
       "      <td>996</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586123</th>\n",
       "      <td>99</td>\n",
       "      <td>997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586124</th>\n",
       "      <td>99</td>\n",
       "      <td>998</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586125</th>\n",
       "      <td>99</td>\n",
       "      <td>999</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1511060 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        userID itemID\n",
       "3            1   1000\n",
       "4            1   1001\n",
       "5            1   1002\n",
       "6            1   1003\n",
       "7            1   1004\n",
       "...        ...    ...\n",
       "1586121     99    995\n",
       "1586122     99    996\n",
       "1586123     99    997\n",
       "1586124     99    998\n",
       "1586125     99    999\n",
       "\n",
       "[1511060 rows x 2 columns]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "users_items_candidates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Score the model to find the top K recommendations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k_scores = score(\n",
    "    loaded_model, \n",
    "    test_df=users_items_candidates,\n",
    "    user_col=USER,\n",
    "    item_col=ITEM,\n",
    "    prediction_col=PREDICTION\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "      <th>prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1642</th>\n",
       "      <td>1</td>\n",
       "      <td>963</td>\n",
       "      <td>5.101374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1109</th>\n",
       "      <td>1</td>\n",
       "      <td>483</td>\n",
       "      <td>5.003863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1026</th>\n",
       "      <td>1</td>\n",
       "      <td>408</td>\n",
       "      <td>4.969304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>780</th>\n",
       "      <td>1</td>\n",
       "      <td>187</td>\n",
       "      <td>4.891338</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1143</th>\n",
       "      <td>1</td>\n",
       "      <td>513</td>\n",
       "      <td>4.880493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1584764</th>\n",
       "      <td>99</td>\n",
       "      <td>1287</td>\n",
       "      <td>1.850188</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1585974</th>\n",
       "      <td>99</td>\n",
       "      <td>862</td>\n",
       "      <td>1.774681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1585488</th>\n",
       "      <td>99</td>\n",
       "      <td>424</td>\n",
       "      <td>1.739392</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1586100</th>\n",
       "      <td>99</td>\n",
       "      <td>976</td>\n",
       "      <td>1.690039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1585506</th>\n",
       "      <td>99</td>\n",
       "      <td>440</td>\n",
       "      <td>1.631488</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1511060 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        userID itemID  prediction\n",
       "1642         1    963    5.101374\n",
       "1109         1    483    5.003863\n",
       "1026         1    408    4.969304\n",
       "780          1    187    4.891338\n",
       "1143         1    513    4.880493\n",
       "...        ...    ...         ...\n",
       "1584764     99   1287    1.850188\n",
       "1585974     99    862    1.774681\n",
       "1585488     99    424    1.739392\n",
       "1586100     99    976    1.690039\n",
       "1585506     99    440    1.631488\n",
       "\n",
       "[1511060 rows x 3 columns]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_k_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate some metrics for our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_map = map(test_df, top_k_scores, col_user=USER, col_item=ITEM, \n",
    "               col_rating=RATING, col_prediction=PREDICTION, \n",
    "               relevancy_method=\"top_k\", k=TOP_K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_ndcg = ndcg_at_k(test_df, top_k_scores, col_user=USER, col_item=ITEM, \n",
    "                      col_rating=RATING, col_prediction=PREDICTION, \n",
    "                      relevancy_method=\"top_k\", k=TOP_K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precision@K of the top-K recommendations, evaluated against test_df.\n",
    "eval_precision = precision_at_k(\n",
    "    test_df,\n",
    "    top_k_scores,\n",
    "    col_user=USER,\n",
    "    col_item=ITEM,\n",
    "    col_rating=RATING,\n",
    "    col_prediction=PREDICTION,\n",
    "    relevancy_method=\"top_k\",\n",
    "    k=TOP_K,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall@K of the top-K recommendations, evaluated against test_df.\n",
    "eval_recall = recall_at_k(\n",
    "    test_df,\n",
    "    top_k_scores,\n",
    "    col_user=USER,\n",
    "    col_item=ITEM,\n",
    "    col_rating=RATING,\n",
    "    col_prediction=PREDICTION,\n",
    "    relevancy_method=\"top_k\",\n",
    "    k=TOP_K,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model:\t\tEmbeddingDotBias\n",
      "Top K:\t\t10\n",
      "MAP:\t\t0.020839\n",
      "NDCG:\t\t0.131409\n",
      "Precision@K:\t0.121633\n",
      "Recall@K:\t0.047912\n"
     ]
    }
   ],
   "source": [
    "# Summarize the ranking metrics computed above, one entry per line.\n",
    "ranking_report = [\n",
    "    \"Model:\\t\\t\" + model.__class__.__name__,\n",
    "    \"Top K:\\t\\t%d\" % TOP_K,\n",
    "    \"MAP:\\t\\t%f\" % eval_map,\n",
    "    \"NDCG:\\t\\t%f\" % eval_ndcg,\n",
    "    \"Precision@K:\\t%f\" % eval_precision,\n",
    "    \"Recall@K:\\t%f\" % eval_recall,\n",
    "]\n",
    "print(\"\\n\".join(ranking_report))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above numbers are lower than those of [SAR](../sar_single_node_movielens.ipynb), which is expected, since the model is explicitly trying to generalize the users and items to the latent factors. Next, we look at how well the model predicts the rating a user would give a movie. For this, we only need to score the user-item pairs in `test_df`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score only the user-item pairs that appear in test_df.\n",
    "scores = score(model, test_df=test_df, user_col=USER, item_col=ITEM, prediction_col=PREDICTION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now calculate some regression metrics (R squared, RMSE, MAE, and explained variance) on the predicted ratings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model:\t\t\tEmbeddingDotBias\n",
      "RMSE:\t\t\t0.910456\n",
      "MAE:\t\t\t0.713525\n",
      "Explained variance:\t0.339586\n",
      "R squared:\t\t0.339563\n"
     ]
    }
   ],
   "source": [
    "# Column names shared by every rating metric below.\n",
    "rating_cols = dict(col_user=USER, col_item=ITEM, col_rating=RATING, col_prediction=PREDICTION)\n",
    "\n",
    "eval_r2 = rsquared(test_df, scores, **rating_cols)\n",
    "eval_rmse = rmse(test_df, scores, **rating_cols)\n",
    "eval_mae = mae(test_df, scores, **rating_cols)\n",
    "eval_exp_var = exp_var(test_df, scores, **rating_cols)\n",
    "\n",
    "# Summarize the rating metrics, one entry per line.\n",
    "regression_report = [\n",
    "    \"Model:\\t\\t\\t\" + model.__class__.__name__,\n",
    "    \"RMSE:\\t\\t\\t%f\" % eval_rmse,\n",
    "    \"MAE:\\t\\t\\t%f\" % eval_mae,\n",
    "    \"Explained variance:\\t%f\" % eval_exp_var,\n",
    "    \"R squared:\\t\\t%f\" % eval_r2,\n",
    "]\n",
    "print(\"\\n\".join(regression_report))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That RMSE is competitive in comparison with other models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.020839167053536133,
       "encoder": "json",
       "name": "map"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "map"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.13140886626622267,
       "encoder": "json",
       "name": "ndcg"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "ndcg"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.12163308589607637,
       "encoder": "json",
       "name": "precision"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "precision"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.04791248067724805,
       "encoder": "json",
       "name": "recall"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "recall"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.9104563887759987,
       "encoder": "json",
       "name": "rmse"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "rmse"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.7135247603946885,
       "encoder": "json",
       "name": "mae"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "mae"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.3395859668119835,
       "encoder": "json",
       "name": "exp_var"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "exp_var"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 0.3395634341199626,
       "encoder": "json",
       "name": "rsquared"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "rsquared"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/notebook_utils.json+json": {
       "data": 63.654639774002135,
       "encoder": "json",
       "name": "train_time"
      }
     },
     "metadata": {
      "notebook_utils": {
       "data": true,
       "display": false,
       "name": "train_time"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Record results for tests - ignore this cell\n",
    "store_metadata(\"map\", eval_map)\n",
    "store_metadata(\"ndcg\", eval_ndcg)\n",
    "store_metadata(\"precision\", eval_precision)\n",
    "store_metadata(\"recall\", eval_recall)\n",
    "store_metadata(\"rmse\", eval_rmse)\n",
    "store_metadata(\"mae\", eval_mae)\n",
    "store_metadata(\"exp_var\", eval_exp_var)\n",
    "store_metadata(\"rsquared\", eval_r2)\n",
    "store_metadata(\"train_time\", train_time.interval)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "recommenders311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
